{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fwmDcxdJw3Nm"
},
"source": [
"# Variational AutoEncoder\n",
"\n",
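"**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
"**Date created:** 2020/05/03<br>\n",
"**Last modified:** 2024/04/24<br>\n",
"**Description:** Convolutional Variational AutoEncoder (VAE), originally trained on MNIST digits; adapted here to 200x200 RGB images loaded from Google Drive."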
"**Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0BJPNJFw3No"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"source": [
"# This notebook uses Keras 3 APIs (keras.ops, keras.random.SeedGenerator).\n",
"# Uncomment the line below only if the runtime ships an older Keras; %pip targets\n",
"# the running kernel's environment, unlike !pip.\n",
"# %pip install --upgrade keras\n"
],
"metadata": {
"id": "o5Tc7p-Xx-lm"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "tyl1_KpGw3No"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YYx6q4r2w3No"
},
"source": [
"## Create a sampling layer"
]
},
{
"cell_type": "code",
"source": [
"class Sampling(layers.Layer):\n",
"    \"\"\"Reparameterization-trick sampling layer.\n",
"\n",
"    Uses (z_mean, z_log_var) to sample z, the latent vector encoding an image:\n",
"    z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, I).\n",
"    Sampling this way keeps z differentiable w.r.t. z_mean and z_log_var.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # Fixed-seed generator so the injected noise is reproducible.\n",
"        self.seed_generator = keras.random.SeedGenerator(1337)\n",
"\n",
"    def call(self, inputs):\n",
"        z_mean, z_log_var = inputs\n",
"        # Dynamic shape: the batch dimension is None when the model is built.\n",
"        batch = ops.shape(z_mean)[0]\n",
"        dim = ops.shape(z_mean)[1]\n",
"        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n",
"        # exp(0.5 * log_var) is the standard deviation.\n",
"        return z_mean + ops.exp(0.5 * z_log_var) * epsilon"
],
"metadata": {
"id": "RsoBWk6VHgHW"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"# Superseded first draft of the encoder (Dense(16) bottleneck, no decoder).\n",
"# Every line in this cell is commented out, so running it has no effect; the\n",
"# encoder and decoder actually used are built in a later code cell.\n",
"\n",
"# latent_dim = 2\n",
"\n",
"# encoder_inputs = keras.Input(shape=(200, 200, 3))\n",
"# x = layers.Conv2D(32, 3, activation=\"relu\", strides=2, padding=\"same\")(encoder_inputs)\n",
"# x = layers.Conv2D(64, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"# x = layers.Flatten()(x)\n",
"# x = layers.Dense(16, activation=\"relu\")(x)\n",
"# z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
"# z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
"# z = Sampling()([z_mean, z_log_var])\n",
"# encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
"# encoder.summary()"
],
"metadata": {
"id": "_5tf0AGGHjcU"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Build the encoder and decoder for 200x200 RGB images"
],
"metadata": {
"id": "Xr93wqVOIrH3"
}
},
{
"cell_type": "code",
"source": [
"# Encoder and decoder for 200x200 RGB inputs.\n",
"# `keras`, `layers`, `np` and the seeded `Sampling` layer come from the cells above.\n",
"# Re-importing `keras`/`layers` from `tensorflow.keras` or redefining `Sampling`\n",
"# here would silently shadow those definitions (and drop the seeded noise).\n",
"\n",
"latent_dim = 2\n",
"image_shape = (200, 200, 3)\n",
"\n",
"# Encoder: three stride-2 convolutions downsample 200 -> 100 -> 50 -> 25.\n",
"encoder_inputs = keras.Input(shape=image_shape)\n",
"x = layers.Conv2D(32, 3, activation=\"relu\", strides=2, padding=\"same\")(encoder_inputs)\n",
"x = layers.Conv2D(64, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"x = layers.Conv2D(128, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"# Remember the feature-map shape so the decoder mirrors it exactly.\n",
"shape_before_flatten = tuple(x.shape[1:])  # (25, 25, 128)\n",
"x = layers.Flatten()(x)\n",
"# NOTE: 25*25*128 = 80,000 inputs, so this Dense layer alone holds ~20.5M params.\n",
"x = layers.Dense(256, activation=\"relu\")(x)\n",
"z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
"z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
"z = Sampling()([z_mean, z_log_var])\n",
"encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
"\n",
"# Decoder: mirror of the encoder, upsampling 25 -> 50 -> 100 -> 200.\n",
"latent_inputs = keras.Input(shape=(latent_dim,))\n",
"x = layers.Dense(int(np.prod(shape_before_flatten)), activation=\"relu\")(latent_inputs)\n",
"x = layers.Reshape(shape_before_flatten)(x)\n",
"x = layers.Conv2DTranspose(128, 3, activation=\"relu\", strides=2, padding=\"same\")(x)  # 50x50\n",
"x = layers.Conv2DTranspose(64, 3, activation=\"relu\", strides=2, padding=\"same\")(x)  # 100x100\n",
"x = layers.Conv2DTranspose(32, 3, activation=\"relu\", strides=2, padding=\"same\")(x)  # 200x200\n",
"# Sigmoid keeps outputs in [0, 1], matching inputs normalized to [0, 1].\n",
"decoder_outputs = layers.Conv2DTranspose(image_shape[-1], 3, activation=\"sigmoid\", padding=\"same\")(x)\n",
"decoder = keras.Model(latent_inputs, decoder_outputs, name=\"decoder\")\n",
"\n",
"# Print summaries to verify model architecture\n",
"encoder.summary()\n",
"decoder.summary()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 774
},
"id": "xMfAWTkzIpfa",
"outputId": "e78cd2ca-ac37-4806-8e1a-9b73d8fda999"
},
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"encoder\"\u001b[0m\n"
],
"text/html": [
"<pre>Model: \"encoder\"\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m896\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │ conv2d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m25\u001b[0m, \u001b[38;5;34m25\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m73,856\u001b[0m │ conv2d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m80000\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv2d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m20,480,256\u001b[0m │ flatten[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ z_mean (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m514\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ z_log_var (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m514\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ sampling (\u001b[38;5;33mSampling\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ z_mean[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ │ │ z_log_var[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n"
],
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer (InputLayer) │ (None, 200, 200, 3) │ 0 │ - │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d (Conv2D) │ (None, 100, 100, 32) │ 896 │ input_layer[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_1 (Conv2D) │ (None, 50, 50, 64) │ 18,496 │ conv2d[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ conv2d_2 (Conv2D) │ (None, 25, 25, 128) │ 73,856 │ conv2d_1[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ flatten (Flatten) │ (None, 80000) │ 0 │ conv2d_2[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ dense (Dense) │ (None, 256) │ 20,480,256 │ flatten[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ z_mean (Dense) │ (None, 2) │ 514 │ dense[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ z_log_var (Dense) │ (None, 2) │ 514 │ dense[0][0] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ sampling (Sampling) │ (None, 2) │ 0 │ z_mean[0][0], │\n",
"│ │ │ │ z_log_var[0][0] │\n",
"└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m20,574,532\u001b[0m (78.49 MB)\n"
],
"text/html": [
"<pre>Total params: 20,574,532 (78.49 MB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m20,574,532\u001b[0m (78.49 MB)\n"
],
"text/html": [
"<pre>Trainable params: 20,574,532 (78.49 MB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
],
"text/html": [
"<pre>Non-trainable params: 0 (0.00 B)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"decoder\"\u001b[0m\n"
],
"text/html": [
"<pre>Model: \"decoder\"\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer_1 (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m80000\u001b[0m) │ \u001b[38;5;34m240,000\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m25\u001b[0m, \u001b[38;5;34m25\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose (\u001b[38;5;33mConv2DTranspose\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m147,584\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_1 (\u001b[38;5;33mConv2DTranspose\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m73,792\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_2 (\u001b[38;5;33mConv2DTranspose\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m18,464\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_3 (\u001b[38;5;33mConv2DTranspose\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m200\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m867\u001b[0m │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
],
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer_1 (InputLayer) │ (None, 2) │ 0 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense_1 (Dense) │ (None, 80000) │ 240,000 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ reshape (Reshape) │ (None, 25, 25, 128) │ 0 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose (Conv2DTranspose) │ (None, 50, 50, 128) │ 147,584 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_1 (Conv2DTranspose) │ (None, 100, 100, 64) │ 73,792 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_2 (Conv2DTranspose) │ (None, 200, 200, 32) │ 18,464 │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_transpose_3 (Conv2DTranspose) │ (None, 200, 200, 3) │ 867 │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m480,707\u001b[0m (1.83 MB)\n"
],
"text/html": [
"Total params: 480,707 (1.83 MB)\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m480,707\u001b[0m (1.83 MB)\n" ], "text/html": [ "
Trainable params: 480,707 (1.83 MB)\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ], "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [
"# `tf` and `keras` come from the setup cell at the top of the notebook.\n",
"\n",
"class VAE(keras.Model):\n",
"    \"\"\"Variational autoencoder wrapping an encoder and a decoder model.\n",
"\n",
"    The encoder must return (z_mean, z_log_var, z); the decoder maps z back to an\n",
"    image whose values are assumed to lie in [0, 1] (binary cross-entropy is used).\n",
"    `fit` runs the custom `train_step`, minimizing reconstruction + KL loss.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encoder, decoder, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.encoder = encoder\n",
"        self.decoder = decoder\n",
"        self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n",
"        self.reconstruction_loss_tracker = keras.metrics.Mean(name=\"reconstruction_loss\")\n",
"        self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n",
"\n",
"    @property\n",
"    def metrics(self):\n",
"        # Listing the trackers here lets Keras reset them between epochs.\n",
"        return [\n",
"            self.total_loss_tracker,\n",
"            self.reconstruction_loss_tracker,\n",
"            self.kl_loss_tracker,\n",
"        ]\n",
"\n",
"    def call(self, inputs):\n",
"        \"\"\"Encode then decode `inputs`, returning the reconstruction.\"\"\"\n",
"        _, _, z = self.encoder(inputs)\n",
"        return self.decoder(z)\n",
"\n",
"    def train_step(self, data):\n",
"        # fit() may deliver (x,) or (x, y) tuples; the VAE only needs x.\n",
"        if isinstance(data, tuple):\n",
"            data = data[0]\n",
"        with tf.GradientTape() as tape:\n",
"            z_mean, z_log_var, z = self.encoder(data)\n",
"            reconstruction = self.decoder(z)\n",
"            # Data is assumed normalized to [0, 1]. binary_crossentropy reduces the\n",
"            # last (channel) axis; the sum then covers height and width per image.\n",
"            reconstruction_loss = tf.reduce_mean(\n",
"                tf.reduce_sum(\n",
"                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n",
"                )\n",
"            )\n",
"            # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over\n",
"            # latent dims and averaged over the batch.\n",
"            kl_loss = -0.5 * tf.reduce_mean(\n",
"                tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)\n",
"            )\n",
"            total_loss = reconstruction_loss + kl_loss\n",
"\n",
"        grads = tape.gradient(total_loss, self.trainable_weights)\n",
"        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
"\n",
"        # Update metrics\n",
"        self.total_loss_tracker.update_state(total_loss)\n",
"        self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
"        self.kl_loss_tracker.update_state(kl_loss)\n",
"\n",
"        return {\n",
"            \"loss\": self.total_loss_tracker.result(),\n",
"            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
"            \"kl_loss\": self.kl_loss_tracker.result(),\n",
"        }\n"
], "metadata": { "id": "yEBi1DFpIphi" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", 
"source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "id": "tngtOMDDK9Y7", "outputId": "6195013f-d04d-4db8-bfd1-ff4ba4233cf8" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] }, { "cell_type": "code", "source": [
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Function to load and preprocess images\n",
"def load_image(path, size=(200, 200)):\n",
"    \"\"\"Load an image file as a float32 RGB array of shape (height, width, 3) in [0, 1].\n",
"\n",
"    Args:\n",
"        path: Path to the image file.\n",
"        size: Target (width, height) passed to PIL's `resize`.\n",
"    \"\"\"\n",
"    # `with` closes the file handle. convert(\"RGB\") drops an alpha channel or\n",
"    # expands grayscale so every image matches the encoder's 3-channel input.\n",
"    with Image.open(path) as img:\n",
"        img = img.convert(\"RGB\").resize(size)\n",
"    # float32 matches the model's dtype and halves memory vs. float64.\n",
"    return np.asarray(img, dtype=np.float32) / 255.0  # Normalize to [0, 1]\n",
"\n",
"image1_path = '/content/drive/MyDrive/machine_learning/vae/pic1.jpg'\n",
"image2_path = '/content/drive/MyDrive/machine_learning/vae/pic2.jpg'\n",
"\n",
"# Load images with preprocessing\n",
"pic_1 = load_image(image1_path)\n",
"pic_2 = load_image(image2_path)\n",
"\n",
"# Expand dimensions to add batch size of 1 for model input\n",
"pic_1_batch = np.expand_dims(pic_1, axis=0)\n",
"pic_2_batch = np.expand_dims(pic_2, axis=0)\n"
], "metadata": { "id": "H43mu71tMKh3" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [
"\n",
"# Display the resized images side by side\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
"for ax, img, title in zip(axes, (pic_1, pic_2), ('Picture 1 - Resized', 'Picture 2 - Resized')):\n",
"    ax.imshow(img)\n",
"    ax.set_title(title)\n",
"plt.show()\n",
"\n",
"# Print image properties to verify\n",
"print(f\"Picture 1 - Shape: {pic_1.shape}, Dtype: {pic_1.dtype}, Min: {pic_1.min()}, Max: {pic_1.max()}\")\n",
"print(f\"Picture 2 - Shape: {pic_2.shape}, Dtype: {pic_2.dtype}, Min: {pic_2.min()}, Max: {pic_2.max()}\")\n"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 470 }, "id": "AT1-hewmMM1Q", "outputId": "494aa5c8-540e-41f6-f878-f0630968f02d" }, "execution_count": 9, "outputs": [ { 
"output_type": "display_data", "data": { "text/plain": [ "